import {clone} from "../util";

/**
 * Build the message thrown when two matrices have incompatible shapes.
 * @param {Mat} mat1 left operand
 * @param {Mat} mat2 right operand
 * @param {string} action verb for the attempted operation, appended to the message
 * @returns {string}
 */
function shapeMismatch (mat1, mat2, action) {
    return `维度（${mat1.rows}, ${mat1.columns}）不能与维度 (${mat2.rows}, ${mat2.columns})${action}`;
}

/**
 * Throw (as a string, like the rest of this module) unless both matrices have identical shapes.
 * @param {Mat} mat1
 * @param {Mat} mat2
 * @param {string} action verb for the attempted operation
 */
function assertSameShape (mat1, mat2, action) {
    if (mat1.rows !== mat2.rows || mat1.columns !== mat2.columns) {
        throw shapeMismatch(mat1, mat2, action);
    }
}

/**
 * Dense numeric matrix stored row-major as an array of row arrays.
 * All errors are thrown as plain strings (kept for compatibility with existing callers).
 */
export default class Mat {
    /**
     * @param {number[][]} data rows of the matrix; stored by reference, not copied
     */
    constructor (data) {
        this.data = data;
        this.rows = data.length;
        // An empty matrix has no first row to measure, so it has 0 columns.
        this.columns = this.rows > 0 ? data[0].length : 0;
    }

    /**
     * Transpose of this matrix, as a new `columns × rows` matrix.
     * @returns {Mat}
     */
    get T () {
        return Mat.create(this.columns, this.rows, (i, j) => this.data[j][i]);
    }

    /**
     * Create a `row × column` matrix whose element (i, j) is `fun(i, j)`.
     * Elements are generated in row-major order.
     * @param {number} row number of rows
     * @param {number} column number of columns
     * @param {(i: number, j: number) => number} fun element generator
     * @returns {Mat}
     * @throws {string} when `fun` is not a function
     */
    static create (row, column, fun) {
        if (typeof fun !== 'function') {
            throw '参数：' + Object.prototype.toString.call(fun) + '不是函数';
        }
        const data = [];
        for (let i = 0; i < row; i++) {
            const rowData = [];
            for (let j = 0; j < column; j++) {
                rowData.push(fun(i, j));
            }
            data.push(rowData);
        }
        return new Mat(data);
    }

    /**
     * Create a `row × column` matrix of zeros.
     * @param {number} row
     * @param {number} column
     * @returns {Mat}
     */
    static zeros (row, column) {
        return this.create(row, column, () => 0);
    }

    /**
     * Create a `row × column` matrix of ones.
     * @param {number} row
     * @param {number} column
     * @returns {Mat}
     */
    static ones (row, column) {
        return this.create(row, column, () => 1);
    }

    /**
     * Create a `row × column` matrix with elements drawn from
     * `range[0] + Math.random() * (range[1] - range[0])`.
     * @param {number} row
     * @param {number} column
     * @param {number[]} range two-element array `[low, high]`; defaults to `[0, 1]`
     * @returns {Mat}
     * @throws {string} when `range` is not a two-element array
     */
    static random (row, column, range = [0, 1]) {
        if (!Array.isArray(range) || range.length !== 2) {
            throw `参数:${Object.prototype.toString.call(range)}参数错误`;
        }
        const [low, high] = range;
        return this.create(row, column, () => low + Math.random() * (high - low));
    }

    /**
     * Create the `row × row` identity matrix.
     * @param {number} row
     * @returns {Mat}
     */
    static identity (row) {
        return this.create(row, row, (i, j) => (i === j ? 1 : 0));
    }

    /**
     * Create a square diagonal matrix.
     * @param {number[]} data elements placed on the main diagonal
     * @returns {Mat}
     */
    static diag (data) {
        return this.create(data.length, data.length, (i, j) => (i === j ? data[i] : 0));
    }

    /**
     * @param {number} index row index
     * @returns {number[]} the row array itself (not a copy)
     */
    row (index) {
        return this.data[index];
    }

    /**
     * @param {number} index column index
     * @returns {number[]} a new array holding that column's values
     */
    column (index) {
        return this.data.map((r) => r[index]);
    }

    /**
     * LUP decomposition with partial pivoting: P·A = L·U, for a square matrix.
     * `pai[i]` is the index of the original row that ends up at position i.
     * The last pivot is not checked here; `lup_solve` rejects a zero pivot.
     * @returns {{L: Mat, U: Mat, pai: number[]}} unit lower-triangular L, upper-triangular U, row permutation
     * @throws {string} 'singular matrix' when a pivot column has no non-zero candidate
     */
    lup () {
        // Work on a row-wise copy so this.data is left untouched.
        const A = this.data.map((r) => r.slice());
        const n = A.length;
        const pai = Array.from({ length: n }, (_, i) => i);

        for (let k = 0; k < n - 1; k++) {
            // Partial pivoting: take the row at or below k with the largest |A[i][k]|.
            let p = 0;
            let k0 = k;
            for (let i = k; i < n; i++) {
                const magnitude = Math.abs(A[i][k]);
                if (magnitude > p) {
                    p = magnitude;
                    k0 = i;
                }
            }
            if (p === 0) {
                throw 'singular matrix';
            }
            [pai[k], pai[k0]] = [pai[k0], pai[k]];
            [A[k], A[k0]] = [A[k0], A[k]];

            // Eliminate below the pivot, storing the multipliers in place (they form L).
            const u = A[k][k];
            for (let j = k + 1; j < n; j++) {
                const l = A[j][k] / u;
                A[j][k] = l;
                for (let i = k + 1; i < n; i++) {
                    A[j][i] -= A[k][i] * l;
                }
            }
        }

        const L = Mat.create(n, n, (i, j) => {
            if (i === j) return 1;
            return i > j ? A[i][j] : 0;
        });
        const U = Mat.create(n, n, (i, j) => (j >= i ? A[i][j] : 0));
        return { L, U, pai };
    }

    /**
     * LU (Doolittle) decomposition without pivoting: A = L·U.
     * NOTE(review): a zero pivot with a non-zero entry below it yields Infinity/NaN
     * entries in L; use `lup` for matrices that need row exchanges.
     * @returns {{L: Mat, U: Mat}}
     */
    lu () {
        const A = this.data.map((r) => r.slice());
        const n = A.length;
        const L = Mat.zeros(n, n);
        const U = Mat.zeros(n, n);
        for (let k = 0; k < n; k++) {
            for (let i = k; i < n; i++) {
                L.data[i][k] = A[i][k] === 0 ? 0 : A[i][k] / A[k][k];
                U.data[k][i] = A[k][i];
            }
            // Remove the contribution of step k from the remaining submatrix.
            for (let i = k; i < n; i++) {
                for (let j = k; j < n; j++) {
                    A[i][j] -= L.data[i][k] * U.data[k][j];
                }
            }
        }
        return { L, U };
    }

    /**
     * Solve A·x = b using the LUP decomposition of this matrix.
     * @param {number[]} b right-hand side, length `this.rows`
     * @param {{L: Mat, U: Mat, pai: number[]}} [decomposition] precomputed result of `this.lup()`,
     *        so several right-hand sides can reuse one factorisation
     * @returns {number[]} solution vector x
     * @throws {string} 'singular matrix' when the matrix is singular
     */
    lup_solve (b, decomposition = this.lup()) {
        const { L, U, pai } = decomposition;
        const n = this.rows;

        // Forward substitution: L·y = P·b.
        const y = [];
        for (let i = 0; i < n; i++) {
            let sum = b[pai[i]];
            for (let j = 0; j < i; j++) {
                sum -= L.data[i][j] * y[j];
            }
            y.push(sum);
        }

        // Back substitution: U·x = y.
        const x = Array.from({ length: n }, () => 0);
        for (let i = n - 1; i >= 0; i--) {
            const pivot = U.data[i][i];
            // lup() never checks the last pivot, so a singular matrix can reach here.
            if (pivot === 0) {
                throw 'singular matrix';
            }
            let sum = y[i];
            for (let j = n - 1; j > i; j--) {
                sum -= U.data[i][j] * x[j];
            }
            x[i] = sum / pivot;
        }
        return x;
    }

    /**
     * Inverse of this square matrix, computed column by column from one LUP factorisation.
     * @returns {Mat}
     * @throws {string} when the matrix is not square or is singular
     */
    inv () {
        const n = this.rows;
        if (n !== this.columns) {
            throw `矩阵维度（${this.rows}, ${this.columns}）不是方阵，不能求逆`;
        }
        const decomposition = this.lup();
        const inverseColumns = [];
        for (let i = 0; i < n; i++) {
            const unit = Array.from({ length: n }, (_, j) => (i === j ? 1 : 0));
            inverseColumns.push(this.lup_solve(unit, decomposition));
        }
        // Each solved vector is a column of the inverse, so transpose the stack.
        return new Mat(inverseColumns).T;
    }

    /**
     * Element-wise matrix addition.
     * @returns {Mat}
     * @throws {string} when shapes differ
     */
    static __add__ (mat1, mat2) {
        assertSameShape(mat1, mat2, '相加');
        return this.create(mat1.rows, mat1.columns, (i, j) => mat1.data[i][j] + mat2.data[i][j]);
    }

    /**
     * Element-wise matrix subtraction (mat1 - mat2). Despite its name this subtracts;
     * the name is kept for compatibility.
     * @returns {Mat}
     * @throws {string} when shapes differ
     */
    static __plus__ (mat1, mat2) {
        assertSameShape(mat1, mat2, '相减');
        return this.create(mat1.rows, mat1.columns, (i, j) => mat1.data[i][j] - mat2.data[i][j]);
    }

    /**
     * Matrix product mat1 · mat2.
     * @returns {Mat} `mat1.rows × mat2.columns`
     * @throws {string} when `mat1.columns !== mat2.rows`
     */
    static __multiply__ (mat1, mat2) {
        if (mat1.columns !== mat2.rows) {
            throw shapeMismatch(mat1, mat2, '相乘');
        }
        return this.create(mat1.rows, mat2.columns, (i, j) => {
            let sum = 0;
            for (let k = 0; k < mat1.columns; k++) {
                sum += mat1.data[i][k] * mat2.data[k][j];
            }
            return sum;
        });
    }

    /**
     * Element-wise (Hadamard) product.
     * @returns {Mat}
     * @throws {string} when shapes differ
     */
    static __and__ (mat1, mat2) {
        assertSameShape(mat1, mat2, '逐元素相乘');
        return this.create(mat1.rows, mat1.columns, (i, j) => mat1.data[i][j] * mat2.data[i][j]);
    }

    /**
     * Element-wise division mat1 ./ mat2.
     * NOTE(review): a zero divisor yields the string '-inf' (not a number) regardless of
     * the dividend's sign; kept as-is because callers may rely on it.
     * @returns {Mat}
     * @throws {string} when shapes differ
     */
    static __mod__ (mat1, mat2) {
        assertSameShape(mat1, mat2, '逐元素相除');
        return this.create(mat1.rows, mat1.columns, (i, j) => (
            mat2.data[i][j] === 0 ? '-inf' : mat1.data[i][j] / mat2.data[i][j]
        ));
    }

    /**
     * Matrix division: mat1 / mat2 is defined as inv(mat2) · mat1.
     * @returns {Mat} `mat2.rows × mat1.columns`
     * @throws {string} when mat2 is not square, when `mat2.columns !== mat1.rows`, or when mat2 is singular
     */
    static __divide__ (mat1, mat2) {
        // inv(mat2) needs a square mat2, and the product needs mat2.columns === mat1.rows.
        if (mat2.rows !== mat2.columns || mat2.columns !== mat1.rows) {
            throw shapeMismatch(mat1, mat2, '相除');
        }
        return this.__multiply__(mat2.inv(), mat1);
    }
}